{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "6f796bd0",
   "metadata": {},
   "source": [
    "### Memory distortions in the extended model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "94923d7f",
   "metadata": {},
   "source": [
    "#### Installation:\n",
    "\n",
    "Tested with tensorflow 2.11.0 and Python 3.10.9."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65aa68b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install -r requirements.txt\n",
    "!pip install importlib-metadata==4.13.0"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c2672f4a",
   "metadata": {},
   "source": [
    "#### Imports and set up:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b26b1a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from end_to_end import *\n",
    "from extended_model import *\n",
    "from extended_model_utils import *\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "import tensorflow_datasets as tfds\n",
    "import pandas \n",
    "import pickle\n",
    "import pandas as pd\n",
    "from PIL import Image, ImageFilter, ImageChops, ImageDraw\n",
    "from scipy.stats import spearmanr, pearsonr\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "dims = (64,64,3)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aff0e459",
   "metadata": {},
   "source": [
    "#### Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c56a9aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = 'shapes3d'\n",
    "ds = tfds.load(dataset, split='train', shuffle_files=False, data_dir='./data/')\n",
    "ds_info = tfds.builder(dataset).info\n",
    "df = tfds.as_dataframe(ds, ds_info)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4c18b74c",
   "metadata": {},
   "source": [
    "#### Load previously trained VAE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0290c564",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# set tensorflow and numpy random seeds to make outputs reproducible\n",
    "tf.random.set_seed(321)\n",
    "np.random.seed(321)\n",
    "\n",
    "# set tensorflow image format to arrays of dim (width, height, num_channels)\n",
    "tf.keras.backend.set_image_data_format('channels_last')\n",
    "\n",
    "# set number of epochs (actually the max number as early stopping is enabled)\n",
    "generative_epochs = 50\n",
    "# set number of images\n",
    "num_ims = 10000\n",
    "\n",
    "params = {'hopfield_beta': 20,\n",
    "          'latent_dim': 20}\n",
    "\n",
    "net, vae = run_end_to_end(dataset='shapes3d', \n",
    "                          generative_epochs=generative_epochs, \n",
    "                          num=num_ims, \n",
    "                          latent_dim=params['latent_dim'], \n",
    "                          interpolate=False,\n",
    "                          do_vector_arithmetic=False,\n",
    "                          kl_weighting=1,\n",
    "                          lr = 0.001,\n",
    "                          hopfield_beta=params['hopfield_beta'])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "db4dc62a",
   "metadata": {},
   "source": [
    "#### Testing distortions at different thresholds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5981c5c",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "thresholds = [0.01, 0.05, 0.1, 0.15, 0.2]\n",
    "errors = []\n",
    "counts = []\n",
    "\n",
    "for t in thresholds:\n",
    "    err, pixel_count, errors_l, counts_l = test_extended_model(vae=vae, dataset=dataset, threshold=t,\n",
    "                                                           latent_dim=params['latent_dim'],\n",
    "                                                           return_errors_and_counts=True, n=100, beta=100)\n",
    "    errors.append(errors_l)\n",
    "    counts.append(counts_l)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8034391-000d-42b1-9610-cb727a9dca2d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_distortions(thresholds, errors, counts):\n",
    "    fig, ax = plt.subplots(figsize=(4,3))\n",
    "    ax2 = ax.twinx()\n",
    "\n",
    "    ax.set_ylabel(\"Reconstruction error\", color='red')\n",
    "    ax2.set_ylabel('No. of sensory feature units', color='blue')\n",
    "\n",
    "    errors_mean = np.mean(errors, axis=1)\n",
    "    errors_sem = np.std(errors, axis=1) / np.sqrt(len(errors[0]))\n",
    "\n",
    "    counts_mean = np.mean(counts, axis=1)\n",
    "    counts_sem = np.std(counts, axis=1) / np.sqrt(len(counts[0]))\n",
    "\n",
    "    ax.errorbar(thresholds, errors_mean, yerr=errors_sem, color='red', fmt='-o', capsize=5, markersize=4)\n",
    "    ax2.errorbar(thresholds, counts_mean, yerr=counts_sem, color='blue', fmt='-o', capsize=5, markersize=4)\n",
    "\n",
    "    plt.title('Reconstruction error and memory \\ndimension at different thresholds')\n",
    "\n",
    "    ax.set_xlabel('Threshold')\n",
    "    plt.savefig(f'hybrid_model/error_thresholds_{dataset}.pdf', bbox_inches = \"tight\")\n",
    "    return fig, thresholds, errors_mean, counts_mean\n",
    "\n",
    "fig, thresholds, errors_mean, counts_mean = plot_distortions(thresholds, errors, counts)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "88ac11c0-4ffc-4c64-ba00-dc9b428f77ee",
   "metadata": {},
   "source": [
    "#### Statistical analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df7f9413-c4fc-4478-8630-9e257e27735b",
   "metadata": {},
   "outputs": [],
   "source": [
    "coef_errors, p_errors = spearmanr(thresholds, errors_mean)\n",
    "coef_counts, p_counts = spearmanr(thresholds, counts_mean)\n",
    "print(f\"Spearman correlation (thresholds vs errors_mean): coef = {coef_errors}, p-value = {p_errors}\")\n",
    "print(f\"Spearman correlation (thresholds vs counts_mean): coef = {coef_counts}, p-value = {p_counts}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e7476c51",
   "metadata": {},
   "source": [
    "#### Simulating semantic distortions of episodic memory"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0fbea10-614a-45a7-90ce-a1c5d7ee4d5e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# filtered_df = df[\n",
    "#     (df['label_floor_hue'] == 1) &\n",
    "#     (df['label_object_hue'] == 2) &\n",
    "#     (df['label_orientation'] == 7) &\n",
    "#     (df['label_scale'] == 0) &\n",
    "#     (df['label_wall_hue'] == 3)\n",
    "# ]\n",
    "\n",
    "# filtered_df['image'] = filtered_df['image']/255\n",
    "\n",
    "# with open('subset.pickle', 'wb') as handle:\n",
    "#     pickle.dump(filtered_df, handle, protocol=pickle.HIGHEST_PROTOCOL)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11489d74",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('subset.pickle', 'rb') as handle:\n",
    "    filtered_df = pickle.load(handle)\n",
    "\n",
    "filtered_df = filtered_df.reset_index(drop=True)\n",
    "filtered_df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e580e4de",
   "metadata": {},
   "outputs": [],
   "source": [
    "attributes_to_match = ['label_floor_hue', 'label_object_hue', 'label_orientation',\n",
    "                       'label_scale', 'label_wall_hue']\n",
    "\n",
    "def get_diff_shape_im(df, i, shape_id):\n",
    "    attribute_values = df.iloc[i][attributes_to_match]\n",
    "    filtered_df = df[(df[attributes_to_match] == attribute_values).all(axis=1)]\n",
    "    new_im = filtered_df.loc[filtered_df['label_shape'] == shape_id, 'image'].iloc[0]\n",
    "    return new_im"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fbce212d-7cc4-42ed-931c-d26db4410cc0",
   "metadata": {},
   "source": [
    "The following functions are used to produce ambiguous stimuli for modelling the Carmichael effect. Gaussian blur is applied to a  circle encompassing the central object in the image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "72ca652b-3aad-4a1a-9370-0da00933bf0e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_circular_mask(size, centre, radius):\n",
    "    mask = Image.new('L', size, 0)\n",
    "    draw = ImageDraw.Draw(mask)\n",
    "    draw.ellipse((centre[0]-radius, centre[1]-radius, centre[0]+radius, centre[1]+radius), fill=255)\n",
    "    inverted_mask = ImageChops.invert(mask)\n",
    "    return inverted_mask\n",
    "\n",
    "def blur_centre(img, centre=(32, 45), radius=15, blur_strength=3):\n",
    "    # Create the mask\n",
    "    mask = create_circular_mask(img.size, centre, radius)\n",
    "\n",
    "    # Apply the mask to create the inner image\n",
    "    inner = Image.composite(img, img.filter(ImageFilter.GaussianBlur(blur_strength)), mask)\n",
    "\n",
    "    # Paste the inner image onto the original image\n",
    "    result = ImageChops.composite(img, inner, mask)\n",
    "    return result\n",
    "\n",
    "def add_white_square(d, dims, seed=0):\n",
    "    random.seed(deterministic_seed(d))\n",
    "    square_size = int(dims[0]/8)\n",
    "    im1 = Image.fromarray((d * 255).astype(\"uint8\"))\n",
    "    im1 = blur_centre(im1)\n",
    "    im2 = Image.fromarray((np.ones((square_size, square_size, 3)) * 255).astype(\"uint8\"))\n",
    "    Image.Image.paste(im1, im2, (random.randrange(0, dims[0]), random.randrange(0, dims[0])))\n",
    "    return (np.array(im1) / 255).reshape((dims[0], dims[0], 3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b156559",
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_extended_model_carmichael(df, dataset='shapes3d', vae=None, beta=5, generative_epochs=100, num_ims=1000, latent_dim=10,\n",
    "                        threshold=0.01, shape_id=0):\n",
    "    dims = dims_dict[dataset]\n",
    "\n",
    "    # Split the data_indices instead of the data and labels directly\n",
    "    test_data = df['image']\n",
    "    \n",
    "    # create variant of this dataset with unusual features\n",
    "    test_data_squares = [add_white_square(d, dims) for d in test_data][0:4]\n",
    "    test_data_squares = np.stack(test_data_squares, axis=0).astype(\"float32\")[0:4]\n",
    "\n",
    "    # get reconstructed images and predicted labels for test data\n",
    "    predictions = get_predictions(test_data_squares, vae)\n",
    "    diff = get_true_pred_diff(test_data_squares, predictions)\n",
    "\n",
    "    # create MHN for predictable and unpredictable elements\n",
    "    net = ContinuousHopfield(np.prod(dims) + latent_dim, beta=beta)\n",
    "    # get elements where error is above threshold and remap to range [-1,1]\n",
    "    # first compute the average error across the three channels\n",
    "    avg_diff = np.mean(abs(diff), axis=-1, keepdims=True)\n",
    "    # next create a mask where the average error is greater than the threshold\n",
    "    mask = avg_diff > threshold\n",
    "    # finally create the output image using the mask and input images\n",
    "    sparse_to_encode = np.where(mask, (test_data_squares * 2) - 1, 0)\n",
    "\n",
    "    # get latent vectors for test memories\n",
    "    _, latents = get_recalled_ims_and_latents(vae, test_data_squares, noise_level=0)\n",
    "\n",
    "    # new code\n",
    "    ims_with_diff_shapes = [get_diff_shape_im(df, i, shape_id) for i in range(4)]\n",
    "    recons, pred_labels = get_recalled_ims_and_latents(vae, np.array(ims_with_diff_shapes), noise_level=0)\n",
    "    pred_labels = np.array([pred_labels[0][i] for i in range(len(pred_labels[0]))])\n",
    "    pred_labels = pred_labels.reshape((4, 20, 1))\n",
    "    fixed_latents = pred_labels\n",
    "       \n",
    "    # encode traces in MHN\n",
    "    a = sparse_to_encode.reshape((4, np.prod(dims), 1))\n",
    "    print(a.shape)\n",
    "    print(pred_labels.shape)\n",
    "    hpc_traces = np.concatenate((a, pred_labels), axis=1)\n",
    "    net.learn(hpc_traces[0:4])\n",
    "\n",
    "    # now visualise recall from a noisy image in the MHN\n",
    "    images_masked_np = noise(hpc_traces, noise_factor=0.3, gaussian=True)\n",
    "\n",
    "    tests = []\n",
    "    label_inputs = []\n",
    "    predictions = []\n",
    "    pred_labels = []\n",
    "\n",
    "    for test_ind in range(4):\n",
    "        test_in = images_masked_np[test_ind].reshape(-1, 1)\n",
    "        test_out = net.retrieve(test_in, max_iter=5)\n",
    "        reconstructed = test_out[0:np.prod(dims)]\n",
    "        input_im = test_in[0:np.prod(dims)]\n",
    "\n",
    "        predictions.append(np.array(reconstructed).reshape((1, dims[0], dims[1], dims[2])))\n",
    "        tests.append(np.array(input_im).reshape((1, dims[0], dims[1], dims[2])))\n",
    "        pred_labels.append(test_out[np.prod(dims):])\n",
    "        label_inputs.append(test_in[np.prod(dims):])\n",
    "\n",
    "    predictions = np.concatenate(predictions, axis=0)\n",
    "    tests = np.concatenate(tests, axis=0)\n",
    "    pred_labels = np.array(pred_labels).reshape(4, latent_dim)\n",
    "    label_inputs = np.array(label_inputs).reshape(4, latent_dim)\n",
    "\n",
    "    # for the first ten test images, display the stages of recall\n",
    "    final_outputs = []\n",
    "    for test_ind in range(4):\n",
    "        fig, final_im = recall_memories(test_data_squares, net, vae, dims, latent_dim, test_ind=test_ind,\n",
    "                                        noise_factor=0.1, threshold=threshold, return_final_im=True,\n",
    "                                        fixed_latents=fixed_latents[test_ind])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ba7ce0b",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "test_extended_model_carmichael(filtered_df, \n",
    "                               dataset='shapes3d', \n",
    "                               vae=vae, \n",
    "                               latent_dim=20,\n",
    "                               threshold = 0.2,\n",
    "                               shape_id=0)\n",
    "test_extended_model_carmichael(filtered_df, \n",
    "                               dataset='shapes3d',\n",
    "                               vae=vae, \n",
    "                               latent_dim=20,\n",
    "                               threshold = 0.2,\n",
    "                               shape_id=2)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
